{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gym\n",
    "import gym_custom_envs\n",
    "import time\n",
    "import random\n",
    "import scipy.linalg as alg\n",
    "from gym import wrappers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "i = 0\n",
    "env = gym.make(\"Quadrotor2D-v0\")\n",
    "env = gym.wrappers.Monitor(env, \"vid\", video_callable=lambda i: True,force=True)\n",
    "\n",
    "# uncomment for testing environment below\n",
    "# for i in range(5):\n",
    "#     s = env.reset()\n",
    "    \n",
    "#     done = False\n",
    "#     while not done:\n",
    "#         a = np.array(np.random.uniform(low=-2.0, high=2.0, size=(2,)))\n",
    "# #         a = np.zeros(2)\n",
    "#         print(round(s[0], 3), round(s[1], 3), round(np.rad2deg(s[2]), 3))\n",
    "#         ns, r, done, _ = env.step(a)\n",
    "#         s = ns\n",
    "#         env.render()\n",
    "#         time.sleep(0.08)\n",
    "#         clear_output(wait=True)\n",
    "#env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M(q) = \n",
      "[[0.2        0.         0.        ]\n",
      " [0.         0.2        0.        ]\n",
      " [0.         0.         0.00416667]]\n",
      "b(q) = \n",
      "[[ 0.    0.  ]\n",
      " [ 1.    1.  ]\n",
      " [ 0.25 -0.25]]\n",
      "iM(q) = \n",
      "[[  5.   0.   0.]\n",
      " [  0.   5.   0.]\n",
      " [  0.   0. 240.]]\n"
     ]
    }
   ],
   "source": [
    "m = env.m\n",
    "I = env.I\n",
    "g = env.gravity\n",
    "r = env.l/2\n",
    "\n",
    "# M(q) & B(q) since C(q, q_dot) and Tau(q) is zero\n",
    "M = np.matrix([\n",
    "    [m, 0, 0],\n",
    "    [0, m, 0],\n",
    "    [0, 0, I],\n",
    "])\n",
    "\n",
    "b = np.matrix([\n",
    "    np.zeros(2),\n",
    "    [1, 1],\n",
    "    [r, -r]\n",
    "])\n",
    "\n",
    "iM = np.linalg.inv(M)\n",
    "\n",
    "print('M(q) = ')\n",
    "print(M)\n",
    "\n",
    "print('b(q) = ')\n",
    "print(b)\n",
    "\n",
    "print('iM(q) = ')\n",
    "print(iM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lqr(A, B, Q, R):\n",
    "\n",
    "    S = np.matrix(alg.solve_continuous_are(A, B, Q, R))\n",
    "    \n",
    "    K = np.matrix(alg.inv(R)*B.T*S)\n",
    "        \n",
    "    return K, S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A = \n",
      "[[  0.   0.   0.   1.   0.   0.]\n",
      " [  0.   0.   0.   0.   1.   0.]\n",
      " [  0.   0.   0.   0.   0.   1.]\n",
      " [  0.   0. -10.   0.   0.   0.]\n",
      " [  0.   0.   0.   0.   0.   0.]\n",
      " [  0.   0.   0.   0.   0.   0.]]\n",
      "B = \n",
      "[[  0.   0.]\n",
      " [  0.   0.]\n",
      " [  0.   0.]\n",
      " [  0.   0.]\n",
      " [  5.   5.]\n",
      " [ 60. -60.]]\n",
      "Q = \n",
      "[[10.          0.          0.          0.          0.          0.        ]\n",
      " [ 0.         10.          0.          0.          0.          0.        ]\n",
      " [ 0.          0.         10.          0.          0.          0.        ]\n",
      " [ 0.          0.          0.          1.          0.          0.        ]\n",
      " [ 0.          0.          0.          0.          1.          0.        ]\n",
      " [ 0.          0.          0.          0.          0.          0.07957747]]\n",
      "R = \n",
      "[[0.02 0.  ]\n",
      " [0.   0.02]]\n",
      "K = \n",
      "[[-15.8113883   15.8113883   23.56497713  -9.97592105   5.30681427\n",
      "    1.54343332]\n",
      " [ 15.8113883   15.8113883  -23.56497713   9.97592105   5.30681427\n",
      "   -1.54343332]]\n"
     ]
    }
   ],
   "source": [
    "db = np.matrix([\n",
    "    [0 , 0, -m*g],\n",
    "    np.zeros(3),\n",
    "    np.zeros(3)\n",
    "    ])\n",
    "\n",
    "A = np.concatenate((\n",
    "    np.concatenate((np.zeros((3, 3)), np.eye(3)), axis=1),\n",
    "    np.concatenate((np.dot(iM, db), np.zeros((3, 3))), axis=1)\n",
    "), axis=0)\n",
    "\n",
    "B = np.concatenate((\n",
    "    np.zeros((3, 2)),\n",
    "    np.dot(iM, b)\n",
    "), axis=0)\n",
    "\n",
    "Q = np.diag([10, 10, 10, 1, 1, (r/ np.pi)])\n",
    "\n",
    "R = np.array([[0.02, 0], [0, 0.02]])\n",
    "\n",
    "K, S = lqr(A, B, Q, R)\n",
    "\n",
    "print('A = ')\n",
    "print(A)\n",
    "print('B = ')\n",
    "print(B)\n",
    "\n",
    "print('Q = ')\n",
    "print(Q)\n",
    "\n",
    "print('R = ')\n",
    "print(R)\n",
    "\n",
    "print('K = ')\n",
    "print(K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    s = env.reset().reshape(1, 6)\n",
    "    \n",
    "    done = False\n",
    "    \n",
    "    while not done:\n",
    "#         K, S = lqr(A, B, Q, R)\n",
    "        \n",
    "        a = -np.dot(K, s.T)\n",
    "\n",
    "#         print('Theta = ', round(np.rad2deg(s[0, 2]), 3))\n",
    "#         print('u = ', a.T)\n",
    "#         print(\"Reaction = \", round((a[0] + a[1])[0, 0] - m*g, 3))\n",
    "#         print(\"Cost = \", round(np.dot(s, np.dot(S, s.T))[0, 0], 3))\n",
    "\n",
    "        ns, r, done, _ = env.step([a[0], a[1]])\n",
    "        s = ns.reshape(1, 6)\n",
    "        env.render()\n",
    "        \n",
    "#         clear_output(wait=True)\n",
    "#         time.sleep(0.008)\n",
    "# env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
